#include <bits/stdc++.h>

// Reads a tree of n vertices whose first m vertices (0-based ids 0..m-1)
// carry fixed integer values, then runs a bottom-up DP rooted at vertex m
// and prints the accumulated cost f[m].
//
// Input:  n m, then n-1 edges as 1-based vertex pairs, then m values for
//         vertices 0..m-1.
// Output: f[m] on a single line.
//
// NOTE(review): the DP assumes every vertex with id >= m has at least one
// child when rooted at m (i.e. vertices 0..m-1 are exactly the leaves);
// otherwise t below would be empty and t[l / 2 - 1] would be out of range.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m;
    std::cin >> n >> m;
    std::vector<std::vector<int>> e(n);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        std::cin >> u >> v;
        --u;
        --v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    // [a[u], b[u]] is the interval of equally-good values for vertex u.
    // For the given vertices it collapses to the single input value.
    std::vector<int> a(n), b(n);
    for (int i = 0; i < m; ++i) {
        int x;
        std::cin >> x;
        a[i] = b[i] = x;
    }

    // Distance from point x to the closed interval [lo, hi] (requires
    // lo <= hi). Evaluated in 64-bit so that subtracting two int values
    // cannot overflow, which the original int arithmetic could.
    auto dist = [](std::int64_t x, std::int64_t lo, std::int64_t hi) -> std::int64_t {
        if (x < lo) {
            return lo - x;
        }
        if (x > hi) {
            return x - hi;
        }
        return 0;
    };

    // Two vertices joined by one edge, both with given values: the answer
    // is simply |a[0] - a[1]|.
    if (n == 2 && m == 2) {
        std::cout << dist(a[0], a[1], a[1]) << '\n';
        return 0;
    }

    std::vector<std::int64_t> f(n);
    auto dp = [&](auto &&self, int u, int fa) -> void {
        // Vertices with given values are fixed; nothing to compute.
        if (u < m) {
            return;
        }
        for (int v : e[u]) {
            if (v != fa) {
                self(self, v, u);
            }
        }
        // Safe to reuse one buffer: all recursive calls have finished
        // before it is filled for u.
        static std::vector<int> t;
        t.clear();
        for (int v : e[u]) {
            if (v != fa) {
                t.push_back(a[v]);
                t.push_back(b[v]);
            }
        }
        // Each child contributes dist(x, [a[v], b[v]]), a convex piecewise
        // linear function with breakpoints a[v] and b[v]. Their sum is
        // minimised on the interval between the two middle breakpoints,
        // which becomes u's interval.
        int l = t.size();
        std::nth_element(t.begin(), t.begin() + l / 2 - 1, t.end());
        a[u] = t[l / 2 - 1];
        std::nth_element(t.begin(), t.begin() + l / 2, t.end());
        b[u] = t[l / 2];
        for (int v : e[u]) {
            if (v != fa) {
                f[u] += f[v] + dist(a[u], a[v], b[v]);
            }
        }
    };
    dp(dp, m, -1);
    std::cout << f[m] << '\n';

    return 0;
}